{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ground'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#获取一个格子的状态\n",
    "def get_state(row, col):\n",
    "    if row != 3:\n",
    "        return 'ground'\n",
    "\n",
    "    if row == 3 and col == 0:\n",
    "        return 'ground'\n",
    "\n",
    "    if row == 3 and col == 11:\n",
    "        return 'terminal'\n",
    "\n",
    "    return 'trap'\n",
    "\n",
    "\n",
    "get_state(0, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 1, -1)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#在一个格子里做一个动作\n",
    "def move(row, col, action):\n",
    "    #如果当前已经在陷阱或者终点，则不能执行任何动作\n",
    "    if get_state(row, col) in ['trap', 'terminal']:\n",
    "        return row, col, 0\n",
    "\n",
    "    #↑\n",
    "    if action == 0:\n",
    "        row -= 1\n",
    "\n",
    "    #↓\n",
    "    if action == 1:\n",
    "        row += 1\n",
    "\n",
    "    #←\n",
    "    if action == 2:\n",
    "        col -= 1\n",
    "\n",
    "    #→\n",
    "    if action == 3:\n",
    "        col += 1\n",
    "\n",
    "    #不允许走到地图外面去\n",
    "    row = max(0, row)\n",
    "    row = min(3, row)\n",
    "    col = max(0, col)\n",
    "    col = min(11, col)\n",
    "\n",
    "    #是陷阱的话，奖励是-100，否则都是-1\n",
    "    reward = -1\n",
    "    if get_state(row, col) == 'trap':\n",
    "        reward = -100\n",
    "\n",
    "    return row, col, reward\n",
    "\n",
    "\n",
    "move(0, 0, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 12, 4)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n",
    "Q = np.zeros([4, 12, 4])\n",
    "\n",
    "Q.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "#根据状态选择一个动作\n",
    "def get_action(row, col):\n",
    "    #有小概率选择随机动作\n",
    "    if random.random() < 0.1:\n",
    "        return random.choice(range(4))\n",
    "\n",
    "    #否则选择分数最高的动作\n",
    "    return Q[row, col].argmax()\n",
    "\n",
    "\n",
    "get_action(0, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_update(row, col, action, reward, next_row, next_col):\n",
    "    #target为下一个格子的最高分数，这里的计算和下一步的动作无关\n",
    "    target = 0.9 * Q[next_row, next_col].max()\n",
    "    #加上本步的分数\n",
    "    target += reward\n",
    "\n",
    "    #value为当前state和action的分数\n",
    "    value = Q[row, col, action]\n",
    "\n",
    "    #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n",
    "    #此处是求两者的差,越接近0越好\n",
    "    update = target - value\n",
    "\n",
    "    #这个0.1相当于lr\n",
    "    update *= 0.1\n",
    "\n",
    "    return update\n",
    "\n",
    "\n",
    "get_update(0, 0, 3, -1, 0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -118\n",
      "100 -49\n",
      "200 -31\n",
      "300 -20\n",
      "400 -109\n",
      "500 -102\n",
      "600 -12\n",
      "700 -13\n",
      "800 -18\n",
      "900 -13\n",
      "1000 -12\n",
      "1100 -13\n",
      "1200 -15\n",
      "1300 -14\n",
      "1400 -105\n"
     ]
    }
   ],
   "source": [
    "#训练\n",
    "def train():\n",
    "    for epoch in range(1500):\n",
    "        #初始化当前位置\n",
    "        row = random.choice(range(4))\n",
    "        col = 0\n",
    "\n",
    "        #初始化第一个动作\n",
    "        action = get_action(row, col)\n",
    "\n",
    "        #计算反馈的和，这个数字应该越来越小\n",
    "        reward_sum = 0\n",
    "\n",
    "        #循环直到到达终点或者掉进陷阱\n",
    "        while get_state(row, col) not in ['terminal', 'trap']:\n",
    "\n",
    "            #执行动作\n",
    "            next_row, next_col, reward = move(row, col, action)\n",
    "            reward_sum += reward\n",
    "\n",
    "            #求新位置的动作\n",
    "            next_action = get_action(next_row, next_col)\n",
    "\n",
    "            #计算分数\n",
    "            update = get_update(row, col, action, reward, next_row, next_col)\n",
    "\n",
    "            #更新分数\n",
    "            Q[row, col, action] += update\n",
    "\n",
    "            #更新当前位置\n",
    "            row = next_row\n",
    "            col = next_col\n",
    "            action = next_action\n",
    "\n",
    "        if epoch % 100 == 0:\n",
    "            print(epoch, reward_sum)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "□□□□□□□□□□□□\n",
      "□↑□□□□□□□□□□\n",
      "□□□□□□□□□□□□\n",
      "□○○○○○○○○○○❤\n"
     ]
    }
   ],
   "source": [
    "#打印游戏，方便测试\n",
    "def show(row, col, action):\n",
    "    graph = [\n",
    "        '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n",
    "        '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n",
    "        '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n",
    "        '○', '○', '○', '○', '○', '❤'\n",
    "    ]\n",
    "\n",
    "    action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n",
    "\n",
    "    graph[row * 12 + col] = action\n",
    "\n",
    "    graph = ''.join(graph)\n",
    "\n",
    "    for i in range(0, 4 * 12, 12):\n",
    "        print(graph[i:i + 12])\n",
    "\n",
    "\n",
    "show(1, 1, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "□□□□□□□□□□□□\n",
      "□□□□□□□□□□□□\n",
      "□□□□□□□□□□□↓\n",
      "□○○○○○○○○○○❤\n"
     ]
    }
   ],
   "source": [
    "from IPython import display\n",
    "import time\n",
    "\n",
    "\n",
    "def test():\n",
    "    #起点\n",
    "    row = random.choice(range(4))\n",
    "    col = 0\n",
    "\n",
    "    #最多玩N步\n",
    "    for _ in range(200):\n",
    "\n",
    "        #获取当前状态，如果状态是终点或者掉陷阱则终止\n",
    "        if get_state(row, col) in ['trap', 'terminal']:\n",
    "            break\n",
    "\n",
    "        #选择最优动作\n",
    "        action = Q[row, col].argmax()\n",
    "\n",
    "        #打印这个动作\n",
    "        display.clear_output(wait=True)\n",
    "        time.sleep(0.1)\n",
    "        show(row, col, action)\n",
    "\n",
    "        #执行动作\n",
    "        row, col, reward = move(row, col, action)\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "→→→→→→↓→→↓→↓\n",
      "↓→→↓→→→→→↓→↓\n",
      "→→→→→→→→→→→↓\n",
      "↑↑↑↑↑↑↑↑↑↑↑↑\n"
     ]
    }
   ],
   "source": [
    "#打印所有格子的动作倾向\n",
    "for row in range(4):\n",
    "    line = ''\n",
    "    for col in range(12):\n",
    "        action = Q[row, col].argmax()\n",
    "        action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n",
    "        line += action\n",
    "    print(line)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第5章-时序差分算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
